# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #  
#
#' @title  Classify all tweets using a pre-trained ensemble classifier
#' @author Hauke Licht
#
# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #

# setup ----

# load packages
library(readr)
library(dplyr)
library(dtplyr)
library(data.table)
library(tidyr)
library(lubridate)
library(purrr)
library(stringr)
library(stringi)
library(cld2)

base_path <- file.path(".")
data_path <- file.path(base_path, "data")
helpers_path <- file.path(base_path, "code", "helpers")

# load political tweet classification helper functions
dir <- file.path(helpers_path, "politicaltweets")
tweet_helper_funs <- list.files(dir, pattern = "\\.R$", full.names = TRUE)
invisible(map(tweet_helper_funs, source))

# load political tweet classification helper data
tweet_helper_data <- list.files(dir, pattern = "\\.rda$", full.names = TRUE)
invisible(map(tweet_helper_data, load, envir = .GlobalEnv))

# load pre-trained classifier ('caretEnsemble' object)
political_classifier <- read_rds(file.path(data_path, "fits", "political_classifier.rds"))

# load independent component (IC) model fitted on (old) tweets 
ic300 <- read_rds(file.path(data_path, "fits", "laser_embeddings_ica.rds")) 

# old tweets ----

# note: we read from the older file because the data were updated in July 2023
#  to include new accounts, but the sampling for human annotation etc. was based
#  on this older database

## create tweet features ----

df <- read_rds(file.path(data_path, "input", "parl_party_tweets_translated_2020-04-14.rds"))
df$display_text_width <- as.integer(df$display_text_width)

# keep for later
tweet_ids <- with(df, paste(user_id, status_id, sep = "_"))

# this is going to take some minutes to finish (approx. 7 minutes)
tweet_features <- create_tweet_features(x = df, 20000, .as.data.table = TRUE)
dim(tweet_features)
rm(df); gc()

# check that all complete:
detect_missings(as_tibble(select(tweet_features, !!training.features$colname[1:25])))

# get independent component representations of tweets' LASER embeddings
tweet_laser_ics <- cbind(id = rownames(ic300$S), as_tibble(ic300$S))
tweet_laser_ics <- separate(tweet_laser_ics, id, c("user_id", "status_id"), sep = "_")

## predict ----

country_splits <- as_tibble(tweet_features) %>% 
  left_join(tweet_laser_ics) %>% 
  split(.$country_iso3c)
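
# A minimal sketch of a guard for the per-country threshold lookup used in the
# prediction step, assuming `classification.thresholds` is a named vector or
# list keyed by the same ISO3 codes as the country splits. The helper name
# `missing_thresholds()` is hypothetical: it is not part of the
# 'politicaltweets' helpers and is not called elsewhere in this script.
missing_thresholds <- function(splits, thresholds) {
  # names of splits for which no threshold entry exists
  setdiff(names(splits), names(thresholds))
}
# possible usage (left commented out so the script's behavior is unchanged):
# stopifnot(length(missing_thresholds(country_splits, classification.thresholds)) == 0)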

preds <- map2_dfr(
    .x = country_splits
    , .y = classification.thresholds[names(country_splits)]
    , function(.x, .y) {
      
      message("Processing country ", unique(.x$country_iso3c))
      
      pred_ <- suppressMessages(classify_tweets(
        x = select(.x, !!training.features$colname)
        , model = political_classifier
        , classification.threshold = .y
        , verbose = FALSE
        , .add = FALSE
        , se = TRUE
      ))
      
      return(
        bind_cols(
          select(.x, country_iso3c, party_id, party_name_short, user_id, status_id, created_at, text, text_en)
          , pred_
        )
      )
    }
  )

preds <- rename_at(preds, 9:12, ~gsub("yes|class", "political", .))

preds_old <- mutate(preds, is_en = cld2::detect_language(text_en) == "en")

rm(tweet_features, tweet_laser_ics, country_splits); gc()

# "new" tweets (collected in June 2023) ----

# df <- read_rds(file.path(data_path, "input", "all_party_tweets.rds"))
df <- read_rds(file.path(data_path, "_tmp", "tweets", "tweets_rtweet_format.rds"))

## create tweet features ----

df$created_at <- as.character(df$created_at)
nrow(df)

# this is going to take some minutes to finish
tweet_features <- create_tweet_features(x = df, .as.data.table = TRUE)

tweet_ids <- with(tweet_features, paste(user_id, status_id, sep = "_"))

# check that all complete:
detect_missings(as_tibble(select(tweet_features, !!training.features$colname[1:25])))

rm(df); gc()

## get LASER embeddings' representations in independent component space ----

fp <- file.path(data_path, "intermediate", "embeddings", "tweet_laser_embeddings.tab")
tem <- fread(fp, sep = "\t", header = TRUE)

tem <- tem[tem$id %in% tweet_ids, ]
gc()
tweet_ids <- tem$id

# get only embeddings
X <- as.matrix(tem[, .SD, .SDcols = patterns("e\\d{4}")])
dim(X)
rm(tem); gc()

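# de-mean column-wise (using means in fitted object obtained from 'training data')
# note: assumes `ic300$X.means` holds one mean per embedding dimension (column);
#  a bare `X - ic300$X.means` would recycle the means in column-major order
#  instead of subtracting the j-th mean from the j-th column
X <- sweep(X, 2, ic300$X.means, FUN = "-")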

# get independent components of tweets' LASER embeddings (XKW = S)
tweet_laser_ics <- X %*% ic300$K %*% ic300$W
rm(X); gc()
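
# A minimal sketch of the projection step (S = XKW) as a reusable function,
# assuming `means` is a vector with one mean per column of `X`, and that `K`
# and `W` are the whitening and unmixing matrices of the fitted IC model.
# The helper name `project_onto_ics()` is hypothetical and is not used
# elsewhere in this script.
project_onto_ics <- function(X, means, K, W) {
  stopifnot(
    ncol(X) == length(means),
    ncol(X) == nrow(K),
    ncol(K) == nrow(W)
  )
  # subtract means[j] from every entry of column j
  X_centered <- sweep(X, 2, means, FUN = "-")
  # map centered embeddings into independent component space
  X_centered %*% K %*% W
}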

tweet_laser_ics <- bind_cols(id = tweet_ids, as_tibble(tweet_laser_ics))
tweet_laser_ics <- separate(tweet_laser_ics, id, c("user_id", "status_id"), sep = "_")

## get tweet metadata ----

tweets <- read_rds(file.path("replication", "data", "input", "all_party_tweets.rds"))

tweet_laser_ics <- tweets %>% 
  select(country_iso3c, party_id, party_name_short, user_id, status_id) %>% 
  inner_join(tweet_laser_ics)

rm(tweets); gc()

## predict ----

country_splits <- as_tibble(tweet_features) %>% 
  inner_join(tweet_laser_ics) %>% 
  split(.$country_iso3c)

preds <- map2_dfr(
  .x = country_splits
  , .y = classification.thresholds[names(country_splits)]
  , function(.x, .y) {
    
    message("Processing country ", unique(.x$country_iso3c))
    
    pred_ <- suppressMessages(classify_tweets(
      x = select(.x, !!training.features$colname)
      , model = political_classifier
      , classification.threshold = .y
      , verbose = FALSE
      , .add = FALSE
      , se = TRUE
    ))
    
    return(
      bind_cols(
        select(.x, country_iso3c, party_id, party_name_short, user_id, status_id, created_at, text)
        , pred_
      )
    )
  }
)

preds <- rename_at(preds, 8:11, ~gsub("yes|class", "political", .))
preds <- mutate(preds, created_at = ymd_hms(created_at))

rm(country_splits, tweet_features, tweet_laser_ics, ic300); gc()

# save to disk ----

out <- bind_rows(
  "no" = preds_old, 
  "yes" = preds,
  .id = "collected_posthoc"
)

fp <- file.path(data_path, "input", "all_tweets_classified_political.rds")
if (!file.exists(fp))
  write_rds(out, fp)
